package edu.northwestern.cbits.purple_robot_manager.models.trees.parsers; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; import edu.northwestern.cbits.purple_robot_manager.models.trees.BranchNode; import edu.northwestern.cbits.purple_robot_manager.models.trees.BranchNode.Condition; import edu.northwestern.cbits.purple_robot_manager.models.trees.BranchNode.Operation; import edu.northwestern.cbits.purple_robot_manager.models.trees.LeafNode; import edu.northwestern.cbits.purple_robot_manager.models.trees.TreeNode; import edu.northwestern.cbits.purple_robot_manager.models.trees.TreeNode.TreeNodeException; /** * Implements a parser for the GraphViz format generated by Weka's J48 learner. * * Example model: * * <pre> * {@code * digraph J48Tree { * N0 [label=\"wifiaccesspointsprobe_current_ssid\" ] * N0->N1 [label=\"= ?\"] * N1 [label=\"alone (0.0)\" shape=box style=filled ] * N0->N2 [label=\"= home\"] * N2 [label=\"alone (8.41/2.0)\" shape=box style=filled ] * N0->N3 [label=\"= 0x\"] * N3 [label=\"robothealthprobe_cpu_usage\" ] * N3->N4 [label=\"<= 0.142857\"] * N4 [label=\"wifiaccesspointsprobe_access_point_count\" ] * N4->N5 [label=\"<= 17\"] * N5 [label=\"acquaintances (2.1/1.1)\" shape=box style=filled ] * N4->N6 [label=\"> 17\"] * N6 [label=\"strangers (2.1/0.1)\" shape=box style=filled ] * N3->N7 [label=\"> 0.142857\"] * N7 [label=\"alone (5.26/2.0)\" shape=box style=filled ] * N0->N8 [label=\"= blerg\"] * N8 [label=\"partner (1.05/0.05)\" shape=box style=filled ] * N0->N9 [label=\"= northwestern\"] * N9 [label=\"runningsoftwareproberunning_tasks_running_tasks_package_name\" ] * N9->N10 [label=\"= ?\"] * N10 [label=\"acquaintances (0.0)\" shape=box style=filled ] * N9->N11 [label=\"= comcbitsmobilyze_pro\"] * N11 [label=\"acquaintances (2.1/0.1)\" shape=box style=filled ] * N9->N12 [label=\"= comandroidlauncher\"] * N12 [label=\"alone (3.15/1.0)\" shape=box style=filled ] * N9->N13 [label=\"= edunorthwesterncbitspurple_robot_manager\"] * N13 [label=\"acquaintances (16.82/7.82)\" shape=box style=filled ] * } * </pre> */ public class WekaJ48TreeParser extends TreeNodeParser { private static final String NUM_INSTANCES = "num_instances"; private static final String NUM_INCORRECT = "num_incorrect"; ArrayList<String> _lines = new ArrayList<>(); /** * Traditional Main method for testing the classes from the desktop * environment. Runs a few tests of the tree above. * * @param args * Not used. */ public static void main(String[] args) { try { TreeNode node = TreeNodeParser .parseString("digraph J48Tree {\nN0 [label=\"wifiaccesspointsprobe_current_ssid\" ]\nN0->N1 [label=\"= ?\"]\nN1 [label=\"alone (0.0)\" shape=box style=filled ]\nN0->N2 [label=\"= home\"]\nN2 [label=\"alone (8.41/2.0)\" shape=box style=filled ]\nN0->N3 [label=\"= 0x\"]\nN3 [label=\"robothealthprobe_cpu_usage\" ]\nN3->N4 [label=\"<= 0.142857\"]\nN4 [label=\"wifiaccesspointsprobe_access_point_count\" ]\nN4->N5 [label=\"<= 17\"]\nN5 [label=\"acquaintances (2.1/1.1)\" shape=box style=filled ]\nN4->N6 [label=\"> 17\"]\nN6 [label=\"strangers (2.1/0.1)\" shape=box style=filled ]\nN3->N7 [label=\"> 0.142857\"]\nN7 [label=\"alone (5.26/2.0)\" shape=box style=filled ]\nN0->N8 [label=\"= blerg\"]\nN8 [label=\"partner (1.05/0.05)\" shape=box style=filled ]\nN0->N9 [label=\"= northwestern\"]\nN9 [label=\"runningsoftwareproberunning_tasks_running_tasks_package_name\" ]\nN9->N10 [label=\"= ?\"]\nN10 [label=\"acquaintances (0.0)\" shape=box style=filled ]\nN9->N11 [label=\"= comcbitsmobilyze_pro\"]\nN11 [label=\"acquaintances (2.1/0.1)\" shape=box style=filled ]\nN9->N12 [label=\"= comandroidlauncher\"]\nN12 [label=\"alone (3.15/1.0)\" shape=box style=filled ]\nN9->N13 [label=\"= edunorthwesterncbitspurple_robot_manager\"]\nN13 [label=\"acquaintances (16.82/7.82)\" shape=box style=filled ]\n}\n"); System.out.println(node.toString(0)); HashMap<String, Object> world = new HashMap<>(); Map<String, Object> prediction = node.fetchPrediction(world); System.out.println("Expect alone. Got " + prediction.get(LeafNode.PREDICTION) + " // " + prediction.get(LeafNode.ACCURACY) + "."); world.put("robothealthprobe_cpu_usage", 0.1); prediction = node.fetchPrediction(world); System.out.println("Expect alone. Got " + prediction.get(LeafNode.PREDICTION) + " // " + prediction.get(LeafNode.ACCURACY) + "."); world.put("wifiaccesspointsprobe_current_ssid", "blerg"); prediction = node.fetchPrediction(world); System.out.println("Expect partner. Got " + prediction.get(LeafNode.PREDICTION) + " // " + prediction.get(LeafNode.ACCURACY) + "."); world.put("wifiaccesspointsprobe_current_ssid", "0x"); world.put("wifiaccesspointsprobe_access_point_count", (double) 20); prediction = node.fetchPrediction(world); System.out.println("Expect strangers. Got " + prediction.get(LeafNode.PREDICTION) + " // " + prediction.get(LeafNode.ACCURACY) + "."); world.put("wifiaccesspointsprobe_current_ssid", "northwestern"); prediction = node.fetchPrediction(world); System.out.println("Expect acquaintances. Got " + prediction.get(LeafNode.PREDICTION) + " // " + prediction.get(LeafNode.ACCURACY) + "."); world.put("runningsoftwareproberunning_tasks_running_tasks_package_name", "comandroidlauncher"); prediction = node.fetchPrediction(world); System.out.println("Expect alone. Got " + prediction.get(LeafNode.PREDICTION) + " // " + prediction.get(LeafNode.ACCURACY) + "."); } catch (ParserNotFound | TreeNodeException e) { e.printStackTrace(); } } /** * Parses the provided content and returns the accompanying decision tree. * * @see edu.northwestern.cbits.purple_robot_manager.models.trees.parsers.TreeNodeParser#parse(java.lang.String) */ public TreeNode parse(String content) throws TreeNodeException { // Extract strings that have meaningful content. Examples: // N0 [label=\"wifiaccesspointsprobe_current_ssid\" ] // N0->N1 [label=\"= ?\"] for (String line : content.split("\\r?\\n")) { if (line.startsWith("digraph J48Tree")) { // Start } else if (line.startsWith("}")) { // End } else this._lines.add(line.trim()); } // The root of the Weka decision trees is always labelled "N0". return this.treeForNode("N0"); } /** * Recursively generates a decision tree based on the node ID. This function * scans the content lines containing information relevant to the node * specified and constructs the appropriate node (branch or leaf). * * @param id * Node ID to be turned into a TreeNode. * * @return TreeNode encoded with the relevant node details, including * descendants. * * @throws TreeNodeException * Thrown on errors constructing the node. */ private TreeNode treeForNode(String id) throws TreeNodeException { for (String line : this._lines) { if (line.startsWith(id + " [label=")) { if (line.contains("shape=box")) { // This line contains a leaf node. Build one... return this.leafNodeForLine(line); } else { // Branch node... BranchNode branch = new BranchNode(); // Split the line by quote tokens... String[] tokens = line.split("\\\""); String feature = tokens[1]; for (String edgeLine : this._lines) { // Is this an edge we care about? if (edgeLine.startsWith(id + "->")) { // It is. Replace irrelevant parts of the string // with tokenizable components. edgeLine = edgeLine.replace(id + "->", ""); edgeLine = edgeLine.replace(" [label=\"", "|"); edgeLine = edgeLine.replace(" ", "|"); edgeLine = edgeLine.replace("\"]", "|"); // Split on tokenizable component. String[] edgeTokens = edgeLine.split("\\|"); // Get the destination node of the this edge. String nextId = edgeTokens[0]; // Create a tree node for the destination node. TreeNode nextNode = this.treeForNode(nextId); // Get the test line components. String comparison = edgeTokens[1]; String value = edgeTokens[2]; if ("=".equals(comparison)) { if ("?".equals(value)) { // Represents missing data. Associate the // destination node with a low-priority // default catch-all. branch.addCondition(Operation.DEFAULT, feature, value, Condition.LOWEST_PRIORITY, nextNode); } else { // Looking for a specific value. Associate // the // node with a normal priority "equals" // condition. branch.addCondition(Operation.EQUALS, feature, value, Condition.DEFAULT_PRIORITY, nextNode); } } else { // Associate the node with the relevant numeric // comparison. if ("<=".equals(comparison)) branch.addCondition(Operation.LESS_THAN_OR_EQUAL_TO, feature, Double.valueOf(value), Condition.DEFAULT_PRIORITY, nextNode); else if (">".equals(comparison)) branch.addCondition(Operation.MORE_THAN, feature, Double.valueOf(value), Condition.DEFAULT_PRIORITY, nextNode); else if (">=".equals(comparison)) branch.addCondition(Operation.MORE_THAN_OR_EQUAL_TO, feature, Double.valueOf(value), Condition.DEFAULT_PRIORITY, nextNode); else if ("<".equals(comparison)) branch.addCondition(Operation.LESS_THAN, feature, Double.valueOf(value), Condition.DEFAULT_PRIORITY, nextNode); } } } return branch; } } } throw new TreeNode.TreeNodeException("Unable to find definition for node with ID '" + id + "'."); } /** * Constructs a LeafNode from a terminal node containing no children. * * @param line * Line representing the leaf node. * * @return LeafNode that returns the prediction found in the provided line. */ private TreeNode leafNodeForLine(String line) { // Extract the label components: prediction + accuracy information. String[] tokens = line.split("\\\""); String label = tokens[1]; String[] labelTokens = label.split(" \\("); HashMap<String, Object> prediction = new HashMap<>(); prediction.put(LeafNode.PREDICTION, labelTokens[0]); // Calculate the accuracy information. String remainder = labelTokens[1].substring(0, labelTokens[1].length() - 1); double accuracy = 1.0; double coverage = 0; double incorrect = 0; if (remainder.contains("/")) { // We have some instances that the tree has misclassified here in // the past. String[] remainderTokens = remainder.split("/"); coverage = Double.parseDouble(remainderTokens[0]); incorrect = Double.parseDouble(remainderTokens[1]); accuracy = (coverage - incorrect) / coverage; } else { // Node is 100% accurate. Let's just count the number of instances // covered by the node. coverage = Double.parseDouble(remainder); } // Add mandatory key and values. prediction.put(LeafNode.PREDICTION, labelTokens[0]); prediction.put(LeafNode.ACCURACY, accuracy); // Add additional format-specific metadata. prediction.put(WekaJ48TreeParser.NUM_INSTANCES, coverage); prediction.put(WekaJ48TreeParser.NUM_INCORRECT, incorrect); return new LeafNode(prediction); } }